{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hugging Face Transformers 微调语言模型-问答任务\n",
    "\n",
    "我们已经学会使用 Pipeline 加载支持问答任务的预训练模型，本教程代码将展示如何微调训练一个支持问答任务的模型。\n",
    "\n",
    "**注意：微调后的模型仍然是通过提取上下文的子串来回答问题的，而不是生成新的文本。**\n",
    "\n",
    "### 模型执行问答效果示例\n",
    "\n",
    "![Widget inference representing the QA task](docs/images/question_answering.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "zVvslsfMIrIh"
   },
   "outputs": [],
   "source": [
    "# 根据你使用的模型和GPU资源情况，调整以下关键参数\n",
    "squad_v2 = True\n",
    "model_checkpoint = \"distilbert-base-uncased\"\n",
    "batch_size = 16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "whPRbBNbIrIl"
   },
   "source": [
    "## 下载数据集\n",
    "\n",
    "在本教程中，我们将使用[斯坦福问答数据集(SQuAD）](https://rajpurkar.github.io/SQuAD-explorer/)。\n",
    "\n",
    "### SQuAD 数据集\n",
    "\n",
    "**斯坦福问答数据集(SQuAD)** 是一个阅读理解数据集，由众包工作者在一系列维基百科文章上提出问题组成。每个问题的答案都是相应阅读段落中的文本片段或范围，或者该问题可能无法回答。\n",
    "\n",
    "SQuAD2.0将SQuAD1.1中的10万个问题与由众包工作者对抗性地撰写的5万多个无法回答的问题相结合，使其看起来与可回答的问题类似。要在SQuAD2.0上表现良好，系统不仅必须在可能时回答问题，还必须确定段落中没有支持任何答案，并放弃回答。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "IreSlFmlIrIm"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 270,
     "referenced_widgets": [
      "69caab03d6264fef9fc5649bffff5e20",
      "3f74532faa86412293d90d3952f38c4a",
      "50615aa59c7247c4804ca5cbc7945bd7",
      "fe962391292a413ca55dc932c4279fa7",
      "299f4b4c07654e53a25f8192bd1d7bbd",
      "ad04ed1038154081bbb0c1444784dcc2",
      "7c667ad22b5740d5a6319f1b1e3a8097",
      "46c2b043c0f84806978784a45a4e203b",
      "80e2943be35f46eeb24c8ab13faa6578",
      "de5956b5008d4fdba807bae57509c393",
      "931db1f7a42f4b46b7ff8c2e1262b994",
      "6c1db72efff5476e842c1386fadbbdba",
      "ccd2f37647c547abb4c719b75a26f2de",
      "d30a66df5c0145e79693e09789d96b81",
      "5fa26fc336274073abbd1d550542ee33",
      "2b34de08115d49d285def9269a53f484",
      "d426be871b424affb455aeb7db5e822e",
      "160bf88485f44f5cb6eaeecba5e0901f",
      "745c0d47d672477b9bb0dae77b926364",
      "d22ab78269cd4ccfbcf70c707057c31b",
      "d298eb19eeff453cba51c2804629d3f4",
      "a7204ade36314c86907c562e0a2158b8",
      "e35d42b2d352498ca3fc8530393786b2",
      "75103f83538d44abada79b51a1cec09e",
      "f6253931d90543e9b5fd0bb2d615f73a",
      "051aa783ff9e47e28d1f9584043815f5",
      "0984b2a14115454bbb009df71c1cf36f",
      "8ab9dfce29854049912178941ef1b289",
      "c9de740e007141958545e269372780a4",
      "cbea68b25d6d4ba09b2ce0f27b1726d5",
      "5781fc45cf8d486cb06ed68853b2c644",
      "d2a92143a08a4951b55bab9bc0a6d0d3",
      "a14c3e40e5254d61ba146f6ec88eae25",
      "c4ffe6f624ce4e978a0d9b864544941a",
      "1aca01c1d8c940dfadd3e7144bb35718",
      "9fbbaae50e6743f2aa19342152398186",
      "fea27ca6c9504fc896181bc1ff5730e5",
      "940d00556cb849b3a689d56e274041c2",
      "5cdf9ed939fb42d4bf77301c80b8afca",
      "94b39ccfef0b4b08bf2fb61bb0a657c1",
      "9a55087c85b74ea08b3e952ac1d73cbe",
      "2361ab124daf47cc885ff61f2899b2af",
      "1a65887eb37747ddb75dc4a40f7285f2",
      "3c946e2260704e6c98593136bd32d921",
      "50d325cdb9844f62a9ecc98e768cb5af",
      "aa781f0cfe454e9da5b53b93e9baabd8",
      "6bb68d3887ef43809eb23feb467f9723",
      "7e29a8b952cf4f4ea42833c8bf55342f",
      "dd5997d01d8947e4b1c211433969b89b",
      "2ace4dc78e2f4f1492a181bcd63304e7",
      "bbee008c2791443d8610371d1f16b62b",
      "31b1c8a2e3334b72b45b083688c1a20c",
      "7fb7c36adc624f7dbbcb4a831c1e4f63",
      "0b7c8f1939074794b3d9221244b1344d",
      "a71908883b064e1fbdddb547a8c41743",
      "2f5223f26c8541fc87e91d2205c39995"
     ]
    },
    "id": "s_AY1ATSIrIq",
    "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c55fc01c398d444b98d188566136c3d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/8.92k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9e532c28fcd427aa37ed8fde293e16b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/16.4M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a880e96da2b46b7a36d121faf42c24a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/1.35M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e39d74f45bb4f6d9a7b02e6555b1a3a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/130319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f0a87f6243a416c9e273ca6044136c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/11873 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "datasets = load_dataset(\"squad_v2\" if squad_v2 else \"squad\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RzfPtOMoIrIu"
   },
   "source": [
    "The `datasets` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "GWiVUF0jIrIv",
    "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'title', 'context', 'question', 'answers'],\n",
       "        num_rows: 130319\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'title', 'context', 'question', 'answers'],\n",
       "        num_rows: 11873\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对比数据集\n",
    "\n",
    "相比快速入门使用的 Yelp 评论数据集，我们可以看到 SQuAD 训练和测试集都新增了用于上下文、问题以及问题答案的列：\n",
    "\n",
    "**YelpReviewFull Dataset：**\n",
    "\n",
    "```json\n",
    "\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['label', 'text'],\n",
    "        num_rows: 650000\n",
    "    })\n",
    "    test: Dataset({\n",
    "        features: ['label', 'text'],\n",
    "        num_rows: 50000\n",
    "    })\n",
    "})\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "X6HrpprwIrIz",
    "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': '56be85543aeaaa14008c9063',\n",
       " 'title': 'Beyoncé',\n",
       " 'context': 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles \"Crazy in Love\" and \"Baby Boy\".',\n",
       " 'question': 'When did Beyonce start becoming popular?',\n",
       " 'answers': {'text': ['in the late 1990s'], 'answer_start': [269]}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 从上下文中组织回复内容\n",
    "\n",
    "我们可以看到答案是通过它们在文本中的起始位置（这里是第515个字符）以及它们的完整文本表示的，这是上面提到的上下文的子字符串。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "i3j8APAoIrI3"
   },
   "outputs": [],
   "source": [
    "from datasets import ClassLabel, Sequence\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "SZy5tRB_IrI7",
    "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>context</th>\n",
       "      <th>question</th>\n",
       "      <th>answers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>57289df84b864d1900164ad2</td>\n",
       "      <td>History_of_India</td>\n",
       "      <td>A Turco-Mongol conqueror in Central Asia, Timur (Tamerlane), attacked the reigning Sultan Nasir-u Din Mehmud of the Tughlaq Dynasty in the north Indian city of Delhi. The Sultan's army was defeated on 17 December 1398. Timur entered Delhi and the city was sacked, destroyed, and left in ruins, after Timur's army had killed and plundered for three days and nights. He ordered the whole city to be sacked except for the sayyids, scholars, and the \"other Muslims\" (artists); 100,000 war prisoners were put to death in one day. The Sultanate suffered significantly from the sacking of Delhi revived briefly under the Lodi Dynasty, but it was a shadow of the former.</td>\n",
       "      <td>What Turko-Mongol attacked and defeated the Sultan of Tughlaq dynasty?</td>\n",
       "      <td>{'text': ['Timur'], 'answer_start': [42]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5acd533407355d001abf3d51</td>\n",
       "      <td>Raleigh,_North_Carolina</td>\n",
       "      <td>After the Civil War began, Governor Zebulon Baird Vance ordered the construction of breastworks around the city as protection from Union troops. During General Sherman's Carolinas Campaign, Raleigh was captured by Union cavalry under the command of General Hugh Judson Kilpatrick on April 13, 1865. As the Confederate cavalry retreated west, the Union soldiers followed, leading to the nearby Battle of Morrisville. The city was spared significant destruction during the War, but due to the economic problems of the post-war period and Reconstruction, with a state economy based on agriculture, it grew little over the next several decades.</td>\n",
       "      <td>When did Hugh Judson Kilpatrick die?</td>\n",
       "      <td>{'text': [], 'answer_start': []}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>57315596a5e9cc1400cdbe9b</td>\n",
       "      <td>Mosaic</td>\n",
       "      <td>The abundant variety of natural life depicted in the Butrint mosaics celebrates the richness of God’s creation; some elements also have specific connotations. The kantharos vase and vine refer to the eucharist, the symbol of the sacrifice of Christ leading to salvation. Peacocks are symbols of paradise and resurrection; shown eating or drinking from the vase they indicate the route to eternal life. Deer or stags were commonly used as images of the faithful aspiring to Christ: \"As a heart desireth the water brook, so my souls longs for thee, O God.\" Water-birds and fish and other sea-creatures can indicate baptism as well as the members of the Church who are christened.</td>\n",
       "      <td>What was depicted on the Butrint mosaics in abundance?</td>\n",
       "      <td>{'text': ['natural life'], 'answer_start': [24]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5ad4d6755b96ef001a10a2be</td>\n",
       "      <td>Red</td>\n",
       "      <td>Red was also a badge of rank. During the Song dynasty (906–1279), officials of the top three ranks wore purple clothes; those of the fourth and fifth wore bright red; those of the sixth and seventh wore green; and the eighth and ninth wore blue. Red was the color worn by the royal guards of honor, and the color of the carriages of the imperial family. When the imperial family traveled, their servants and accompanying officials carried red and purple umbrellas. Of an official who had talent and ambition, it was said \"he is so red he becomes purple.\"</td>\n",
       "      <td>What dynasty took place from 909-1276?</td>\n",
       "      <td>{'text': [], 'answer_start': []}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>572c00d6f182dd1900d7c7b2</td>\n",
       "      <td>Tennessee</td>\n",
       "      <td>In February 1861, secessionists in Tennessee's state government—led by Governor Isham Harris—sought voter approval for a convention to sever ties with the United States, but Tennessee voters rejected the referendum by a 54–46% margin. The strongest opposition to secession came from East Tennessee (which later tried to form a separate Union-aligned state). Following the Confederate attack upon Fort Sumter in April and Lincoln's call for troops from Tennessee and other states in response, Governor Isham Harris began military mobilization, submitted an ordinance of secession to the General Assembly, and made direct overtures to the Confederate government. The Tennessee legislature ratified an agreement to enter a military league with the Confederate States on May 7, 1861. On June 8, 1861, with people in Middle Tennessee having significantly changed their position, voters approved a second referendum calling for secession, becoming the last state to do so.</td>\n",
       "      <td>Which region of Tennessee swung in favor of secession in the June 1861 referendum?</td>\n",
       "      <td>{'text': ['Middle Tennessee'], 'answer_start': [812]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5a8cde4cfd22b3001a8d8f75</td>\n",
       "      <td>Memory</td>\n",
       "      <td>Brain areas involved in the neuroanatomy of memory such as the hippocampus, the amygdala, the striatum, or the mammillary bodies are thought to be involved in specific types of memory. For example, the hippocampus is believed to be involved in spatial learning and declarative learning, while the amygdala is thought to be involved in emotional memory. Damage to certain areas in patients and animal models and subsequent memory deficits is a primary source of information. However, rather than implicating a specific area, it could be that damage to adjacent areas, or to a pathway traveling through the area is actually responsible for the observed deficit. Further, it is not sufficient to describe memory, and its counterpart, learning, as solely dependent on specific brain regions. Learning and memory are attributed to changes in neuronal synapses, thought to be mediated by long-term potentiation and long-term depression.</td>\n",
       "      <td>Can you pin point certain areas of the heart to certain memories</td>\n",
       "      <td>{'text': [], 'answer_start': []}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>56df8e3e38dc42170015203e</td>\n",
       "      <td>Alexander_Graham_Bell</td>\n",
       "      <td>Meanwhile, Elisha Gray was also experimenting with acoustic telegraphy and thought of a way to transmit speech using a water transmitter. On February 14, 1876, Gray filed a caveat with the U.S. Patent Office for a telephone design that used a water transmitter. That same morning, Bell's lawyer filed Bell's application with the patent office. There is considerable debate about who arrived first and Gray later challenged the primacy of Bell's patent. Bell was in Boston on February 14 and did not arrive in Washington until February 26.</td>\n",
       "      <td>What material did Elisha Gray use to convey sound?</td>\n",
       "      <td>{'text': ['water'], 'answer_start': [119]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>5726dfaedd62a815002e9381</td>\n",
       "      <td>Madonna_(entertainer)</td>\n",
       "      <td>In mid-2004 Madonna embarked on the Re-Invention World Tour in the U.S., Canada, and Europe. It became the highest-grossing tour of 2004, earning around $120 million and became the subject of her documentary I'm Going to Tell You a Secret. In November 2004, she was inducted into the UK Music Hall of Fame as one of its five founding members, along with The Beatles, Elvis Presley, Bob Marley, and U2. In January 2005, Madonna performed a cover version of the John Lennon song \"Imagine\" at Tsunami Aid. She also performed at the Live 8 benefit concert in London in July 2005.</td>\n",
       "      <td>HOw much did the tour earn?</td>\n",
       "      <td>{'text': ['around $120 million'], 'answer_start': [146]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>57282ec84b864d1900164681</td>\n",
       "      <td>PlayStation_3</td>\n",
       "      <td>The PlayStation Store is an online virtual market available to users of Sony's PlayStation 3 (PS3) and PlayStation Portable (PSP) game consoles via the PlayStation Network. The Store offers a range of downloadable content both for purchase and available free of charge. Available content includes full games, add-on content, playable demos, themes and game and movie trailers. The service is accessible through an icon on the XMB on PS3 and PSP. The PS3 store can also be accessed on PSP via a Remote Play connection to PS3. The PSP store is also available via the PC application, Media Go. As of September 24, 2009, there have been over 600 million downloads from the PlayStation Store worldwide.</td>\n",
       "      <td>What do you click on in the PS3 interface to get to the PlayStation Store?</td>\n",
       "      <td>{'text': ['an icon'], 'answer_start': [411]}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5a10dee906e79900185c343d</td>\n",
       "      <td>Internet_service_provider</td>\n",
       "      <td>Just as their customers pay them for Internet access, ISPs themselves pay upstream ISPs for Internet access. An upstream ISP usually has a larger network than the contracting ISP or is able to provide the contracting ISP with access to parts of the Internet the contracting ISP by itself has no access to.</td>\n",
       "      <td>What does an upstream ISP provide for customers?</td>\n",
       "      <td>{'text': [], 'answer_start': []}</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(datasets[\"train\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n9qywopnIrJH"
   },
   "source": [
    "## 预处理数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "eXNLu_-nIrJI"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32e9bcbebf484ede80740b917de286c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b269133bf924b0bbdb67139aaab5223",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "514e52a792244f10b33b85e820bcb189",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab1c4d48df884b4e8f90de9c6015067a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "    \n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Vl6IidfdIrJK"
   },
   "source": [
    "以下断言确保我们的 Tokenizers 使用的是 FastTokenizer（Rust 实现，速度和功能性上有一定优势）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "您可以在大模型表上查看哪种类型的模型具有可用的快速标记器，哪种类型没有。\n",
    "\n",
    "您可以直接在两个句子上调用此标记器（一个用于答案，一个用于上下文）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "a5hBlsrHIrJL",
    "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer(\"What is your name?\", \"My name is Sylvain.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tokenizer 进阶操作\n",
    "\n",
    "在问答预处理中的一个特定问题是如何处理非常长的文档。\n",
    "\n",
    "在其他任务中，当文档的长度超过模型最大句子长度时，我们通常会截断它们，但在这里，删除上下文的一部分可能会导致我们丢失正在寻找的答案。\n",
    "\n",
    "为了解决这个问题，我们允许数据集中的一个（长）示例生成多个输入特征，每个特征的长度都小于模型的最大长度（或我们设置的超参数）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The maximum length of a feature (question and context)\n",
    "max_length = 384 \n",
    "# The authorized overlap between two part of the context when splitting it is needed.\n",
    "doc_stride = 128 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 超出最大长度的文本数据处理\n",
    "\n",
    "下面，我们从训练集中找出一个超过最大长度（384）的文本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, example in enumerate(datasets[\"train\"]):\n",
    "    if len(tokenizer(example[\"question\"], example[\"context\"])[\"input_ids\"]) > 384:\n",
    "        break\n",
    "# 挑选出来超过384（最大长度）的数据样例\n",
    "example = datasets[\"train\"][i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "437"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(tokenizer(example[\"question\"], example[\"context\"])[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 截断上下文不保留超出部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "384"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(tokenizer(example[\"question\"],\n",
    "              example[\"context\"],\n",
    "              max_length=max_length,\n",
    "              truncation=\"only_second\")[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 关于截断的策略\n",
    "\n",
    "- 直接截断超出部分: truncation=`only_second`\n",
    "- 仅截断上下文（context），保留问题（question）：`return_overflowing_tokens=True` & 设置`stride`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_example = tokenizer(\n",
    "    example[\"question\"],\n",
    "    example[\"context\"],\n",
    "    max_length=max_length,\n",
    "    truncation=\"only_second\",\n",
    "    return_overflowing_tokens=True,\n",
    "    stride=doc_stride\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用此策略截断后，Tokenizer 将返回多个 `input_ids` 列表。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[384, 192]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[len(x) for x in tokenized_example[\"input_ids\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "解码两个输入特征，可以看到重叠的部分："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] beyonce got married in 2008 to whom? [SEP] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single \" crazy in love \", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song \" single ladies ( put a ring on it ) \" and the top - five songs \" if i were a boy \" and \" halo \". achieving the accomplishment of becoming her longest - running hot 100 single in her career, \" halo \"'s success in the us helped beyonce attain more top - ten singles on the list than any other woman during the 2000s. it also included the successful \" sweet dreams \", and singles \" diva \", \" ego \", \" broken - hearted girl \" and \" video phone \". the music video for \" single ladies \" has been parodied and imitated around the world, spawning the \" first major dance craze \" of the internet age according to the toronto star. the video has won several awards, including best video at the 2009 mtv europe music awards, the 2009 scottish mobo awards, and the 2009 bet awards. at the 2009 mtv video music awards, the video was nominated for nine awards, ultimately winning three including video of the year. its failure to win the best female video category, which went to american country pop singer taylor swift's \" you belong with me \", led to kanye west interrupting the ceremony and beyonce [SEP]\n",
      "[CLS] beyonce got married in 2008 to whom? [SEP] single ladies \" has been parodied and imitated around the world, spawning the \" first major dance craze \" of the internet age according to the toronto star. the video has won several awards, including best video at the 2009 mtv europe music awards, the 2009 scottish mobo awards, and the 2009 bet awards. at the 2009 mtv video music awards, the video was nominated for nine awards, ultimately winning three including video of the year. its failure to win the best female video category, which went to american country pop singer taylor swift's \" you belong with me \", led to kanye west interrupting the ceremony and beyonce improvising a re - presentation of swift's award during her own acceptance speech. in march 2009, beyonce embarked on the i am... world tour, her second headlining worldwide concert tour, consisting of 108 shows, grossing $ 119. 5 million. [SEP]\n"
     ]
    }
   ],
   "source": [
    "for x in tokenized_example[\"input_ids\"][:2]:\n",
    "    print(tokenizer.decode(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 使用 offsets_mapping 获取原始的 input_ids\n",
    "\n",
    "设置 `return_offsets_mapping=True`，将使得截断分割生成的多个 input_ids 列表中的 token，通过映射保留原始文本的 input_ids。\n",
    "\n",
    "如下所示：第一个标记（[CLS]）的起始和结束字符都是（0, 0），因为它不对应问题/答案的任何部分，然后第二个标记与问题(question)的字符0到3相同."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 0), (0, 7), (8, 11), (12, 19), (20, 22), (23, 27), (28, 30), (31, 35), (35, 36), (0, 0), (0, 2), (3, 8), (9, 10), (10, 11), (12, 16), (16, 17), (18, 25), (26, 33), (34, 37), (38, 39), (39, 40), (41, 44), (45, 53), (54, 62), (63, 68), (69, 77), (78, 80), (81, 82), (83, 88), (89, 93), (93, 96), (97, 99), (100, 103), (104, 113), (114, 119), (120, 123), (124, 127), (128, 133), (134, 140), (141, 146), (146, 147), (148, 149), (150, 152), (152, 153), (153, 154), (154, 155), (156, 161), (162, 168), (168, 169), (170, 172), (173, 182), (182, 183), (183, 184), (185, 189), (190, 194), (195, 197), (198, 205), (206, 208), (208, 209), (210, 214), (214, 215), (216, 217), (218, 220), (220, 221), (221, 222), (222, 223), (224, 229), (230, 236), (237, 240), (241, 249), (250, 252), (253, 261), (262, 264), (264, 265), (266, 270), (271, 273), (274, 277), (278, 284), (285, 291), (291, 292), (293, 296), (297, 302), (303, 311), (312, 322), (323, 330), (330, 331), (331, 332), (333, 338), (339, 342), (343, 348), (349, 355), (355, 356), (357, 366), (367, 373), (374, 377), (378, 384), (385, 387), (388, 391), (392, 396), (397, 403)]\n"
     ]
    }
   ],
   "source": [
    "tokenized_example = tokenizer(\n",
    "    example[\"question\"],\n",
    "    example[\"context\"],\n",
    "    max_length=max_length,\n",
    "    truncation=\"only_second\",\n",
    "    return_overflowing_tokens=True,\n",
    "    return_offsets_mapping=True,\n",
    "    stride=doc_stride\n",
    ")\n",
    "print(tokenized_example[\"offset_mapping\"][0][:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因此，我们可以使用这个映射来找到答案在给定特征中的起始和结束标记的位置。\n",
    "\n",
    "我们只需区分偏移的哪些部分对应于问题，哪些部分对应于上下文。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beyonce Beyonce\n"
     ]
    }
   ],
   "source": [
    "first_token_id = tokenized_example[\"input_ids\"][0][1]\n",
    "offsets = tokenized_example[\"offset_mapping\"][0][1]\n",
    "print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example[\"question\"][offsets[0]:offsets[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "got got\n"
     ]
    }
   ],
   "source": [
    "second_token_id = tokenized_example[\"input_ids\"][0][2]\n",
    "offsets = tokenized_example[\"offset_mapping\"][0][2]\n",
    "print(tokenizer.convert_ids_to_tokens([second_token_id])[0], example[\"question\"][offsets[0]:offsets[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Beyonce got married in 2008 to whom?'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example[\"question\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "借助`tokenized_example`的`sequence_ids`方法，我们可以方便的区分token的来源编号：\n",
    "\n",
    "- 对于特殊标记：返回None，\n",
    "- 对于正文Token：返回句子编号（从0开始编号）。\n",
    "\n",
    "综上，现在我们可以很方便的在一个输入特征中找到答案的起始和结束 Token。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[None, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]\n"
     ]
    }
   ],
   "source": [
    "sequence_ids = tokenized_example.sequence_ids()\n",
    "print(sequence_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18 19\n"
     ]
    }
   ],
   "source": [
    "answers = example[\"answers\"]\n",
    "start_char = answers[\"answer_start\"][0]\n",
    "end_char = start_char + len(answers[\"text\"][0])\n",
    "\n",
    "# 当前span在文本中的起始标记索引。\n",
    "token_start_index = 0\n",
    "while sequence_ids[token_start_index] != 1:\n",
    "    token_start_index += 1\n",
    "\n",
    "# 当前span在文本中的结束标记索引。\n",
    "token_end_index = len(tokenized_example[\"input_ids\"][0]) - 1\n",
    "while sequence_ids[token_end_index] != 1:\n",
    "    token_end_index -= 1\n",
    "\n",
    "# 检测答案是否超出span范围（如果超出范围，该特征将以CLS标记索引标记）。\n",
    "offsets = tokenized_example[\"offset_mapping\"][0]\n",
    "if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):\n",
    "    # 将token_start_index和token_end_index移动到答案的两端。\n",
    "    # 注意：如果答案是最后一个单词，我们可以移到最后一个标记之后（边界情况）。\n",
    "    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:\n",
    "        token_start_index += 1\n",
    "    start_position = token_start_index - 1\n",
    "    while offsets[token_end_index][1] >= end_char:\n",
    "        token_end_index -= 1\n",
    "    end_position = token_end_index + 1\n",
    "    print(start_position, end_position)\n",
    "else:\n",
    "    print(\"答案不在此特征中。\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印检查是否准确找到了起始位置："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jay z\n",
      "Jay Z\n"
     ]
    }
   ],
   "source": [
    "# 通过查找 offset mapping 位置，解码 context 中的答案 \n",
    "print(tokenizer.decode(tokenized_example[\"input_ids\"][0][start_position: end_position+1]))\n",
    "# 直接打印 数据集中的标准答案（answer[\"text\"])\n",
    "print(answers[\"text\"][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 关于填充的策略\n",
    "\n",
    "- 对于没有超过最大长度的文本，填充补齐长度。\n",
    "- 对于需要左侧填充的模型，交换 question 和 context 顺序"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_on_right = tokenizer.padding_side == \"right\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 整合以上所有预处理步骤\n",
    "\n",
    "让我们将所有内容整合到一个函数中，并将其应用到训练集。\n",
    "\n",
    "针对不可回答的情况（上下文过长，答案在另一个特征中），我们为开始和结束位置都设置了cls索引。\n",
    "\n",
    "如果allow_impossible_answers标志为False，我们还可以简单地从训练集中丢弃这些示例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_train_features(examples):\n",
    "    # 一些问题的左侧可能有很多空白字符，这对我们没有用，而且会导致上下文的截断失败\n",
    "    # （标记化的问题将占用大量空间）。因此，我们删除左侧的空白字符。\n",
    "    examples[\"question\"] = [q.lstrip() for q in examples[\"question\"]]\n",
    "\n",
    "    # 使用截断和填充对我们的示例进行标记化，但保留溢出部分，使用步幅（stride）。\n",
    "    # 当上下文很长时，这会导致一个示例可能提供多个特征，其中每个特征的上下文都与前一个特征的上下文有一些重叠。\n",
    "    tokenized_examples = tokenizer(\n",
    "        examples[\"question\" if pad_on_right else \"context\"],\n",
    "        examples[\"context\" if pad_on_right else \"question\"],\n",
    "        truncation=\"only_second\" if pad_on_right else \"only_first\",\n",
    "        max_length=max_length,\n",
    "        stride=doc_stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    # 由于一个示例可能给我们提供多个特征（如果它具有很长的上下文），我们需要一个从特征到其对应示例的映射。这个键就提供了这个映射关系。\n",
    "    sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n",
    "    # 偏移映射将为我们提供从令牌到原始上下文中的字符位置的映射。这将帮助我们计算开始位置和结束位置。\n",
    "    offset_mapping = tokenized_examples.pop(\"offset_mapping\")\n",
    "\n",
    "    # 让我们为这些示例进行标记！\n",
    "    tokenized_examples[\"start_positions\"] = []\n",
    "    tokenized_examples[\"end_positions\"] = []\n",
    "\n",
    "    for i, offsets in enumerate(offset_mapping):\n",
    "        # 我们将使用 CLS 特殊 token 的索引来标记不可能的答案。\n",
    "        input_ids = tokenized_examples[\"input_ids\"][i]\n",
    "        cls_index = input_ids.index(tokenizer.cls_token_id)\n",
    "\n",
    "        # 获取与该示例对应的序列（以了解上下文和问题是什么）。\n",
    "        sequence_ids = tokenized_examples.sequence_ids(i)\n",
    "\n",
    "        # 一个示例可以提供多个跨度，这是包含此文本跨度的示例的索引。\n",
    "        sample_index = sample_mapping[i]\n",
    "        answers = examples[\"answers\"][sample_index]\n",
    "        # 如果没有给出答案，则将cls_index设置为答案。\n",
    "        if len(answers[\"answer_start\"]) == 0:\n",
    "            tokenized_examples[\"start_positions\"].append(cls_index)\n",
    "            tokenized_examples[\"end_positions\"].append(cls_index)\n",
    "        else:\n",
    "            # 答案在文本中的开始和结束字符索引。\n",
    "            start_char = answers[\"answer_start\"][0]\n",
    "            end_char = start_char + len(answers[\"text\"][0])\n",
    "\n",
    "            # 当前跨度在文本中的开始令牌索引。\n",
    "            token_start_index = 0\n",
    "            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):\n",
    "                token_start_index += 1\n",
    "\n",
    "            # 当前跨度在文本中的结束令牌索引。\n",
    "            token_end_index = len(input_ids) - 1\n",
    "            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):\n",
    "                token_end_index -= 1\n",
    "\n",
    "            # 检测答案是否超出跨度（在这种情况下，该特征的标签将使用CLS索引）。\n",
    "            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):\n",
    "                tokenized_examples[\"start_positions\"].append(cls_index)\n",
    "                tokenized_examples[\"end_positions\"].append(cls_index)\n",
    "            else:\n",
    "                # 否则，将token_start_index和token_end_index移到答案的两端。\n",
    "                # 注意：如果答案是最后一个单词（边缘情况），我们可以在最后一个偏移之后继续。\n",
    "                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:\n",
    "                    token_start_index += 1\n",
    "                tokenized_examples[\"start_positions\"].append(token_start_index - 1)\n",
    "                while offsets[token_end_index][1] >= end_char:\n",
    "                    token_end_index -= 1\n",
    "                tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n",
    "\n",
    "    return tokenized_examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zS-6iXTkIrJT"
   },
   "source": [
    "#### datasets.map 的进阶使用\n",
    "\n",
    "使用 `datasets.map` 方法将 `prepare_train_features` 应用于所有训练、验证和测试数据：\n",
    "\n",
    "- batched: 批量处理数据。\n",
    "- remove_columns: 因为预处理更改了样本的数量，所以在应用它时需要删除旧列。\n",
    "- load_from_cache_file：是否使用datasets库的自动缓存\n",
    "\n",
    "datasets 库针对大规模数据，实现了高效缓存机制，能够自动检测传递给 map 的函数是否已更改（因此需要不使用缓存数据）。如果在调用 map 时设置 `load_from_cache_file=False`，可以强制重新应用预处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "id": "DDtsaJeVIrJT",
    "outputId": "aa4734bf-4ef5-4437-9948-2c16363da719"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebef1473873c4f5cba7f0369ee346a48",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/130319 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ad847a7eaca432881896fb59e7a9ae5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/11873 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenized_datasets = datasets.map(prepare_train_features,\n",
    "                                  batched=True,\n",
    "                                  remove_columns=datasets[\"train\"].column_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "545PP3o8IrJV"
   },
   "source": [
    "## 微调模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FBiW8UpKIrJW"
   },
   "source": [
    "现在我们的数据已经准备好用于训练，我们可以下载预训练模型并进行微调。\n",
    "\n",
    "由于我们的任务是问答，我们使用 `AutoModelForQuestionAnswering` 类。(对比 Yelp 评论打分使用的是 `AutoModelForSequenceClassification` 类）\n",
    "\n",
    "警告通知我们正在丢弃一些权重（`vocab_transform` 和 `vocab_layer_norm` 层），并随机初始化其他一些权重（`pre_classifier` 和 `classifier` 层）。在微调模型情况下是绝对正常的，因为我们正在删除用于预训练模型的掩码语言建模任务的头部，并用一个新的头部替换它，对于这个新头部，我们没有预训练的权重，所以库会警告我们在用它进行推理之前应该对这个模型进行微调，而这正是我们要做的事情。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "id": "TlqNaB8jIrJW",
    "outputId": "84916cf3-6e6c-47f3-d081-032ec30a4132"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06a19cba1c2c4456876b64133668d7d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_N8urzhyIrJY"
   },
   "source": [
    "#### 训练超参数（TrainingArguments）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "id": "Bliy8zgjIrJY"
   },
   "outputs": [],
   "source": [
    "batch_size=64\n",
    "model_dir = f\"models/{model_checkpoint}-finetuned-squad\"\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=model_dir,\n",
    "    evaluation_strategy = \"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Data Collator（数据整理器）\n",
    "\n",
    "数据整理器将训练数据整理为批次数据，用于模型训练时的批次处理。本教程使用默认的 `default_data_collator`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import default_data_collator\n",
    "\n",
    "data_collator = default_data_collator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rXuFTAzDIrJe"
   },
   "source": [
    "### 实例化训练器（Trainer）\n",
    "\n",
    "为了减少训练时间（需要大量算力支持），我们不在本教程的训练模型过程中计算模型评估指标。\n",
    "\n",
    "而是训练完成后，再独立进行模型评估。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "id": "imY1oC3SIrJf"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### GPU 使用情况\n",
    "\n",
    "训练数据与模型配置：\n",
    "\n",
    "- SQUAD v1.1\n",
    "- model_checkpoint = \"distilbert-base-uncased\"\n",
    "- batch_size = 64\n",
    "\n",
    "NVIDIA GPU 使用情况：\n",
    "\n",
    "```shell\n",
    "Wed Mar 27 18:40:16 2024\r\n",
    "+-----------------------------------------------------------------------------------------+\r\n",
    "| NVIDIA-SMI 551.31                 Driver Version: 551.31         CUDA Version: 12.4     |\r\n",
    "|-----------------------------------------+------------------------+----------------------+\r\n",
    "| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |\r\n",
    "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\r\n",
    "|                                         |                        |               MIG M. |\r\n",
    "|=========================================+========================+======================|\r\n",
    "|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:03:00.0 Off |                  N/A |\r\n",
    "| 30%   55C    P2            259W /  320W |   15502MiB /  16376MiB |     84%      Default |\r\n",
    "|                                         |                        |                  N/A |\r\n",
    "+-----------------------------------------+------------------------+----------------------+\r\n",
    "\r\n",
    "+-----------------------------------------------------------------------------------------+\r\n",
    "| Processes:                                                                              |\r\n",
    "|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |\r\n",
    "|        ID   ID                                                               Usage      |\r\n",
    "|=========================================================================================|\r\n",
    "|    0   N/A  N/A      5172    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     15540    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     16712    C+G   ....Search_cw5n1h2txyewy\\SearchApp.exe      N/A      |\r\n",
    "|    0   N/A  N/A     17380    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     26960    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     29408    C+G   ...\\Docker\\frontend\\Docker Desktop.exe      N/A      |\r\n",
    "|    0   N/A  N/A     29992    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     30992    C+G   C:\\Windows\\explorer.exe                     N/A      |\r\n",
    "|    0   N/A  N/A     39256    C+G   ...\\PyCharm 2023.3.3\\bin\\pycharm64.exe      N/A      |\r\n",
    "|    0   N/A  N/A     50520    C+G   ...CBS_cw5n1h2txyewy\\TextInputHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     51448    C+G   ...oogle\\Chrome\\Application\\chrome.exe      N/A      |\r\n",
    "|    0   N/A  N/A     52148    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     56100    C+G   ...oogle\\Chrome\\Application\\chrome.exe      N/A      |\r\n",
    "|    0   N/A  N/A     61012    C+G   D:\\Apifox\\Apifox.exe                        N/A      |\r\n",
    "|    0   N/A  N/A     61232    C+G   ....Search_cw5n1h2txyewy\\SearchApp.exe      N/A      |\r\n",
    "|    0   N/A  N/A     62044    C+G   ...5n1h2txyewy\\ShellExperienceHost.exe      N/A      |\r\n",
    "|    0   N/A  N/A     73068    C+G   ...oaming\\360se6\\Application\\360se.exe      N/A      |\r\n",
    "+-----------------------------------------------------------------------------------------+\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "uNx5pyRlIrJh",
    "outputId": "077e661e-d36c-469b-89b8-7ff7f73541ec"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6177' max='6177' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6177/6177 47:27, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.404100</td>\n",
       "      <td>1.294796</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.144500</td>\n",
       "      <td>1.264730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.007800</td>\n",
       "      <td>1.309779</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6177, training_loss=1.3159433704219976, metrics={'train_runtime': 2848.8102, 'train_samples_per_second': 138.746, 'train_steps_per_second': 2.168, 'total_flos': 3.873165421863629e+16, 'train_loss': 1.3159433704219976, 'epoch': 3.0})"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练完成后，第一时间保存模型权重文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_to_save = trainer.save_model(model_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**评估模型输出需要一些额外的处理：将模型的预测映射回上下文的部分。**\n",
    "\n",
    "模型直接输出的是预测答案的`起始位置`和`结束位置`的**logits**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['loss', 'start_logits', 'end_logits'])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "for batch in trainer.get_eval_dataloader():\n",
    "    break\n",
    "batch = {k: v.to(trainer.args.device) for k, v in batch.items()}\n",
    "with torch.no_grad():\n",
    "    output = trainer.model(**batch)\n",
    "output.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模型的输出是一个类似字典的对象，其中包含损失（因为我们提供了标签），以及起始和结束logits。我们不需要损失来进行预测，让我们看一下logits："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 384]), torch.Size([64, 384]))"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.start_logits.shape, output.end_logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 49,  37,  72,  80, 157,  17,  44,  99, 152, 202, 119,  52,  22,  31,\n",
       "           0, 120, 252,  94,  90,   0,   0,  56,  88, 152,  23,   0,   0, 138,\n",
       "           0,   0, 143,   0,  72,  32,   0,   0,  93,   0,  93, 105,  40,  13,\n",
       "          33,  55,  99,  29,  80,   0,  28,  29,   0,   0, 139,  70,  49,  17,\n",
       "          47,  62,  74,  56,  16,   0,   0,  12], device='cuda:0'),\n",
       " tensor([ 49,  40,  76,  81, 157,  19,  44, 101, 153, 204, 120,  52,  26,  33,\n",
       "           0, 121, 252,  97,  91,   0,   0,  56,  94,   0,  24,   0,   0, 142,\n",
       "           0,  52, 149,   0,  72,   0,   0,   0,  94,  27,  94, 112,  58,  15,\n",
       "          50,  56, 104,  30,  81,   0,  29,  30,   0,   0, 141,  71,  51,   0,\n",
       "          49,  63,  82,  52,  17,   0,   0,  15], device='cuda:0'))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 如何从模型输出的位置 logit 组合成答案\n",
    "\n",
    "我们有每个特征和每个标记的logit。在每个特征中为每个标记预测答案最明显的方法是，将起始logits的最大索引作为起始位置，将结束logits的最大索引作为结束位置。\n",
    "\n",
    "在许多情况下这种方式效果很好，但是如果此预测给出了不可能的结果该怎么办？比如：起始位置可能大于结束位置，或者指向问题中的文本片段而不是答案。在这种情况下，我们可能希望查看第二好的预测，看它是否给出了一个可能的答案，并选择它。\n",
    "\n",
    "选择第二好的答案并不像选择最佳答案那么容易：\n",
    "- 它是起始logits中第二佳索引与结束logits中最佳索引吗？\n",
    "- 还是起始logits中最佳索引与结束logits中第二佳索引？\n",
    "- 如果第二好的答案也不可能，那么对于第三好的答案，情况会更加棘手。\n",
    "\n",
    "为了对答案进行分类，\n",
    "1. 将使用通过添加起始和结束logits获得的分数\n",
    "1. 设计一个名为`n_best_size`的超参数，限制不对所有可能的答案进行排序。\n",
    "1. 我们将选择起始和结束logits中的最佳索引，并收集这些预测的所有答案。\n",
    "1. 在检查每一个是否有效后，我们将按照其分数对它们进行排序，并保留最佳的答案。\n",
    "\n",
    "以下是我们如何在批次中的第一个特征上执行此操作的示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_best_size = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "start_logits = output.start_logits[0].cpu().numpy()\n",
    "end_logits = output.end_logits[0].cpu().numpy()\n",
    "\n",
    "# 获取最佳的起始和结束位置的索引：\n",
    "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "\n",
    "valid_answers = []\n",
    "\n",
    "# 遍历起始位置和结束位置的索引组合\n",
    "for start_index in start_indexes:\n",
    "    for end_index in end_indexes:\n",
    "        if start_index <= end_index:  # 需要进一步测试以检查答案是否在上下文中\n",
    "            valid_answers.append(\n",
    "                {\n",
    "                    \"score\": start_logits[start_index] + end_logits[end_index],\n",
    "                    \"text\": \"\"  # 我们需要找到一种方法来获取与上下文中答案对应的原始子字符串\n",
    "                }\n",
    "            )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "然后，我们可以根据它们的得分对`valid_answers`进行排序，并仅保留最佳答案。唯一剩下的问题是如何检查给定的跨度是否在上下文中（而不是问题中），以及如何获取其中的文本。为此，我们需要向我们的验证特征添加两个内容：\n",
    "\n",
    "- 生成该特征的示例的ID（因为每个示例可以生成多个特征，如前所示）；\n",
    "- 偏移映射，它将为我们提供从标记索引到上下文中字符位置的映射。\n",
    "\n",
    "这就是为什么我们将使用以下函数稍微不同于`prepare_train_features`来重新处理验证集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_validation_features(examples):\n",
    "    # 一些问题的左侧有很多空白，这些空白并不有用且会导致上下文截断失败（分词后的问题会占用很多空间）。\n",
    "    # 因此我们移除这些左侧空白\n",
    "    examples[\"question\"] = [q.lstrip() for q in examples[\"question\"]]\n",
    "\n",
    "    # 使用截断和可能的填充对我们的示例进行分词，但使用步长保留溢出的令牌。这导致一个长上下文的示例可能产生\n",
    "    # 几个特征，每个特征的上下文都会稍微与前一个特征的上下文重叠。\n",
    "    tokenized_examples = tokenizer(\n",
    "        examples[\"question\" if pad_on_right else \"context\"],\n",
    "        examples[\"context\" if pad_on_right else \"question\"],\n",
    "        truncation=\"only_second\" if pad_on_right else \"only_first\",\n",
    "        max_length=max_length,\n",
    "        stride=doc_stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    # 由于一个示例在上下文很长时可能会产生几个特征，我们需要一个从特征映射到其对应示例的映射。这个键就是为了这个目的。\n",
    "    sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n",
    "\n",
    "    # 我们保留产生这个特征的示例ID，并且会存储偏移映射。\n",
    "    tokenized_examples[\"example_id\"] = []\n",
    "\n",
    "    for i in range(len(tokenized_examples[\"input_ids\"])):\n",
    "        # 获取与该示例对应的序列（以了解哪些是上下文，哪些是问题）。\n",
    "        sequence_ids = tokenized_examples.sequence_ids(i)\n",
    "        context_index = 1 if pad_on_right else 0\n",
    "\n",
    "        # 一个示例可以产生几个文本段，这里是包含该文本段的示例的索引。\n",
    "        sample_index = sample_mapping[i]\n",
    "        tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n",
    "\n",
    "        # 将不属于上下文的偏移映射设置为None，以便容易确定一个令牌位置是否属于上下文。\n",
    "        tokenized_examples[\"offset_mapping\"][i] = [\n",
    "            (o if sequence_ids[k] == context_index else None)\n",
    "            for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n",
    "        ]\n",
    "\n",
    "    return tokenized_examples\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将`prepare_validation_features`应用到整个验证集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "855f7232322f4bb3b4db100649beba1f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/11873 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "validation_features = datasets[\"validation\"].map(\n",
    "    prepare_validation_features,\n",
    "    batched=True,\n",
    "    remove_columns=datasets[\"validation\"].column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can grab the predictions for all features by using the `Trainer.predict` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "raw_predictions = trainer.predict(validation_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Trainer`会隐藏模型不使用的列（在这里是`example_id`和`offset_mapping`，我们需要它们进行后处理），所以我们需要将它们重新设置回来："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_features.set_format(type=validation_features.format[\"type\"], columns=list(validation_features.features.keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们可以改进之前的测试：\n",
    "\n",
    "由于在偏移映射中，当它对应于问题的一部分时，我们将其设置为None，因此可以轻松检查答案是否完全在上下文中。我们还可以从考虑中排除非常长的答案（可以调整的超参数）。\n",
    "\n",
    "展开说下具体实现：\n",
    "- 首先从模型输出中获取起始和结束的逻辑值（logits），这些值表明答案在文本中可能开始和结束的位置。\n",
    "- 然后，它使用偏移映射（offset_mapping）来找到这些逻辑值在原始文本中的具体位置。\n",
    "- 接下来，代码遍历可能的开始和结束索引组合，排除那些不在上下文范围内或长度不合适的答案。\n",
    "- 对于有效的答案，它计算出一个分数（基于开始和结束逻辑值的和），并将答案及其分数存储起来。\n",
    "- 最后，它根据分数对答案进行排序，并返回得分最高的几个答案。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_answer_length = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 13.97304, 'text': 'France'},\n",
       " {'score': 9.775538,\n",
       "  'text': 'France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway'},\n",
       " {'score': 8.809256,\n",
       "  'text': 'France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland'},\n",
       " {'score': 8.508085,\n",
       "  'text': 'France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark'},\n",
       " {'score': 8.003988, 'text': 'France.'},\n",
       " {'score': 7.6764975, 'text': 'in France'},\n",
       " {'score': 6.162384, 'text': 'a region in France'},\n",
       " {'score': 5.856285, 'text': 'Normandy, a region in France'},\n",
       " {'score': 5.549666,\n",
       "  'text': 'French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France'},\n",
       " {'score': 5.3652115, 'text': 'Iceland and Norway'},\n",
       " {'score': 5.16555, 'text': 'Denmark, Iceland and Norway'},\n",
       " {'score': 4.7449675, 'text': 'region in France'},\n",
       " {'score': 4.3989286, 'text': 'Iceland'},\n",
       " {'score': 4.1992674, 'text': 'Denmark, Iceland'},\n",
       " {'score': 3.8980973, 'text': 'Denmark'},\n",
       " {'score': 3.791533, 'text': 'Norway'},\n",
       " {'score': 3.7146401, 'text': 'France. They were descended from Norse'},\n",
       " {'score': 3.478996,\n",
       "  'text': 'in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway'},\n",
       " {'score': 3.3824978,\n",
       "  'text': 'Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France'},\n",
       " {'score': 3.2616644,\n",
       "  'text': 'France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark,'}]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start_logits = output.start_logits[0].cpu().numpy()\n",
    "end_logits = output.end_logits[0].cpu().numpy()\n",
    "offset_mapping = validation_features[0][\"offset_mapping\"]\n",
    "\n",
    "# 第一个特征来自第一个示例。对于更一般的情况，我们需要将example_id匹配到一个示例索引\n",
    "context = datasets[\"validation\"][0][\"context\"]\n",
    "\n",
    "# 收集最佳开始/结束逻辑的索引：\n",
    "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "valid_answers = []\n",
    "for start_index in start_indexes:\n",
    "    for end_index in end_indexes:\n",
    "        # 不考虑超出范围的答案，原因是索引超出范围或对应于输入ID的部分不在上下文中。\n",
    "        if (\n",
    "            start_index >= len(offset_mapping)\n",
    "            or end_index >= len(offset_mapping)\n",
    "            or offset_mapping[start_index] is None\n",
    "            or offset_mapping[end_index] is None\n",
    "        ):\n",
    "            continue\n",
    "        # 不考虑长度小于0或大于max_answer_length的答案。\n",
    "        if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n",
    "            continue\n",
    "        if start_index <= end_index: # 我们需要细化这个测试，以检查答案是否在上下文中\n",
    "            start_char = offset_mapping[start_index][0]\n",
    "            end_char = offset_mapping[end_index][1]\n",
    "            valid_answers.append(\n",
    "                {\n",
    "                    \"score\": start_logits[start_index] + end_logits[end_index],\n",
    "                    \"text\": context[start_char: end_char]\n",
    "                }\n",
    "            )\n",
    "\n",
    "valid_answers = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[:n_best_size]\n",
    "valid_answers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印比较模型输出和标准答案（Ground-truth）是否一致:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['France', 'France', 'France', 'France'],\n",
       " 'answer_start': [159, 159, 159, 159]}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"validation\"][0][\"answers\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**模型最高概率的输出与标准答案一致**\n",
    "\n",
    "正如上面的代码所示，这在第一个特征上很容易，因为我们知道它来自第一个示例。\n",
    "\n",
    "对于其他特征，我们需要建立一个示例与其对应特征的映射关系。\n",
    "\n",
    "此外，由于一个示例可以生成多个特征，我们需要将由给定示例生成的所有特征中的所有答案汇集在一起，然后选择最佳答案。\n",
    "\n",
    "下面的代码构建了一个示例索引到其对应特征索引的映射关系："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "examples = datasets[\"validation\"]\n",
    "features = validation_features\n",
    "\n",
    "example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n",
    "features_per_example = collections.defaultdict(list)\n",
    "for i, feature in enumerate(features):\n",
    "    features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当`squad_v2 = True`时，有一定概率出现不可能的答案（impossible answer)。\n",
    "\n",
    "上面的代码仅保留在上下文中的答案，我们还需要获取不可能答案的分数（其起始和结束索引对应于CLS标记的索引）。\n",
    "\n",
    "当一个示例生成多个特征时，我们必须在所有特征中的不可能答案都预测出现不可能答案时（因为一个特征可能之所以能够预测出不可能答案，是因为答案不在它可以访问的上下文部分），这就是为什么一个示例中不可能答案的分数是该示例生成的每个特征中的不可能答案的分数的最小值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):\n",
    "    all_start_logits, all_end_logits = raw_predictions\n",
    "    # 构建一个从示例到其对应特征的映射。\n",
    "    example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n",
    "    features_per_example = collections.defaultdict(list)\n",
    "    for i, feature in enumerate(features):\n",
    "        features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n",
    "\n",
    "    # 我们需要填充的字典。\n",
    "    predictions = collections.OrderedDict()\n",
    "\n",
    "    # 日志记录。\n",
    "    print(f\"正在后处理 {len(examples)} 个示例的预测，这些预测分散在 {len(features)} 个特征中。\")\n",
    "\n",
    "    # 遍历所有示例！\n",
    "    for example_index, example in enumerate(tqdm(examples)):\n",
    "        # 这些是与当前示例关联的特征的索引。\n",
    "        feature_indices = features_per_example[example_index]\n",
    "\n",
    "        min_null_score = None # 仅在squad_v2为True时使用。\n",
    "        valid_answers = []\n",
    "        \n",
    "        context = example[\"context\"]\n",
    "        # 遍历与当前示例关联的所有特征。\n",
    "        for feature_index in feature_indices:\n",
    "            # 我们获取模型对这个特征的预测。\n",
    "            start_logits = all_start_logits[feature_index]\n",
    "            end_logits = all_end_logits[feature_index]\n",
    "            # 这将允许我们将logits中的某些位置映射到原始上下文中的文本跨度。\n",
    "            offset_mapping = features[feature_index][\"offset_mapping\"]\n",
    "\n",
    "            # 更新最小空预测。\n",
    "            cls_index = features[feature_index][\"input_ids\"].index(tokenizer.cls_token_id)\n",
    "            feature_null_score = start_logits[cls_index] + end_logits[cls_index]\n",
    "            if min_null_score is None or min_null_score < feature_null_score:\n",
    "                min_null_score = feature_null_score\n",
    "\n",
    "            # 浏览所有的最佳开始和结束logits，为 `n_best_size` 个最佳选择。\n",
    "            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n",
    "            for start_index in start_indexes:\n",
    "                for end_index in end_indexes:\n",
    "                    # 不考虑超出范围的答案，原因是索引超出范围或对应于输入ID的部分不在上下文中。\n",
    "                    if (\n",
    "                        start_index >= len(offset_mapping)\n",
    "                        or end_index >= len(offset_mapping)\n",
    "                        or offset_mapping[start_index] is None\n",
    "                        or offset_mapping[end_index] is None\n",
    "                    ):\n",
    "                        continue\n",
    "                    # 不考虑长度小于0或大于max_answer_length的答案。\n",
    "                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n",
    "                        continue\n",
    "\n",
    "                    start_char = offset_mapping[start_index][0]\n",
    "                    end_char = offset_mapping[end_index][1]\n",
    "                    valid_answers.append(\n",
    "                        {\n",
    "                            \"score\": start_logits[start_index] + end_logits[end_index],\n",
    "                            \"text\": context[start_char: end_char]\n",
    "                        }\n",
    "                    )\n",
    "        \n",
    "        if len(valid_answers) > 0:\n",
    "            best_answer = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[0]\n",
    "        else:\n",
    "            # 在极少数情况下我们没有一个非空预测，我们创建一个假预测以避免失败。\n",
    "            best_answer = {\"text\": \"\", \"score\": 0.0}\n",
    "        \n",
    "        # 选择我们的最终答案：最佳答案或空答案（仅适用于squad_v2）\n",
    "        if not squad_v2:\n",
    "            predictions[example[\"id\"]] = best_answer[\"text\"]\n",
    "        else:\n",
    "            answer = best_answer[\"text\"] if best_answer[\"score\"] > min_null_score else \"\"\n",
    "            predictions[example[\"id\"]] = answer\n",
    "\n",
    "    return predictions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在原始结果上应用后处理问答结果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "正在后处理 11873 个示例的预测，这些预测分散在 12134 个特征中。\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cc7227e258245ba9108844c87be1ca4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/11873 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "final_predictions = postprocess_qa_predictions(datasets[\"validation\"], validation_features, raw_predictions.predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用 `datasets.load_metric` 中加载 `SQuAD v2` 的评估指标"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_11463/2330875496.py:3: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
      "  metric = load_metric(\"squad_v2\" if squad_v2 else \"squad\")\n",
      "/root/anaconda3/envs/llm/lib/python3.11/site-packages/datasets/load.py:752: FutureWarning: The repository for squad_v2 contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.16.1/metrics/squad_v2/squad_v2.py\n",
      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
      "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "385c7328698f4815bea90b111590ac7b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/2.25k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "234eea32bbb64f5f8f55e179ac2b2d55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading extra modules:   0%|          | 0.00/3.19k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "\n",
    "metric = load_metric(\"squad_v2\" if squad_v2 else \"squad\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们可以调用上面定义的函数进行评估。\n",
    "\n",
    "只需稍微调整一下预测和标签的格式，因为它期望的是一系列字典而不是一个大字典。\n",
    "\n",
    "在使用`squad_v2`数据集时，我们还需要设置`no_answer_probability`参数（我们在这里将其设置为0.0，因为如果我们选择了答案，我们已经将答案设置为空）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'exact': 61.8798955613577,\n",
       " 'f1': 65.47383068097734,\n",
       " 'total': 11873,\n",
       " 'HasAns_exact': 63.30971659919028,\n",
       " 'HasAns_f1': 70.50789333253066,\n",
       " 'HasAns_total': 5928,\n",
       " 'NoAns_exact': 60.45416316232128,\n",
       " 'NoAns_f1': 60.45416316232128,\n",
       " 'NoAns_total': 5945,\n",
       " 'best_exact': 61.8798955613577,\n",
       " 'best_exact_thresh': 0.0,\n",
       " 'best_f1': 65.47383068097753,\n",
       " 'best_f1_thresh': 0.0}"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if squad_v2:\n",
    "    formatted_predictions = [{\"id\": k, \"prediction_text\": v, \"no_answer_probability\": 0.0} for k, v in final_predictions.items()]\n",
    "else:\n",
    "    formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in final_predictions.items()]\n",
    "references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\n",
    "metric.compute(predictions=formatted_predictions, references=references)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Homework：加载本地保存的模型，进行评估和再训练更高的 F1 Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_model = AutoModelForQuestionAnswering.from_pretrained(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_trainer = Trainer(\n",
    "    trained_model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_trainer.train()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Question Answering on SQUAD",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
